# RUN: mlir-linalg-ods-yaml-gen %s --o-ods-decl=- | FileCheck %s --check-prefix=ODS
# RUN: mlir-linalg-ods-yaml-gen %s --o-impl=- | FileCheck %s --check-prefix=IMPL

# @linalg_structured_op
# def test1(O=TensorDef(T, S.M, S.N, output=True),
#           cast=TypeFnAttrDef(default=TypeFn.cast_signed)):
#   """Title.

#   Detailed description.
#   """
#   O[D.m, D.n] = cast(T, const(42)) + cast(T, index(D.n))

# Config for op `test1` (C++ class Test1Op): one output tensor O plus a
# type-function attribute named `cast` that defaults to cast_signed. The
# FileCheck expectations after this document cover the generated op definition
# (summary/description, arguments, builders) and the generated region builder.
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: test1
  cpp_class_name: Test1Op
  # The expectations after this document look for "Title." as the op summary
  # and "Detailed description." as the op description.
  doc: |-
    Title.

    Detailed description.
structured_op: !LinalgStructuredOpConfig
  # Operand and attribute definitions, in declaration order.
  args:
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: T
    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
  # Attribute that selects the type function used by the assignment below.
  - !LinalgOperandDefConfig
    name: cast
    kind: type_fn_attr
    default_fn: cast_signed
  # One indexing map, for the only tensor operand (O).
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
  iterator_types:
  - parallel
  - parallel
  # O = add(cast(T, 42 : i64), cast(T, index 1)); both casts name the `cast`
  # attribute (attr_name) rather than a fixed function (fn_name).
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        fn_name: add
        operands:
        - !ScalarExpression
          scalar_fn:
            kind: type
            attr_name: cast
            type_var: T
            operands:
            - !ScalarExpression
              scalar_const: '42 : i64'
        - !ScalarExpression
          scalar_fn:
            kind: type
            attr_name: cast
            type_var: T
            operands:
            # Corresponds to index(D.n) in the Python definition above.
            - !ScalarExpression
              scalar_index: 1

# ODS-LABEL:  def Test1Op : LinalgStructuredBase_Op<"test1"

#       ODS:  let summary = [{ Title. }];
#  ODS-NEXT:  let description = [{
#  ODS-NEXT:    Detailed description.
#  ODS-NEXT:  }];

#       ODS:  let arguments =
#  ODS-NEXT:    Variadic<AnyType>:$inputs,
#  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
#  ODS-NEXT:    DefaultValuedOptionalAttr<TypeFnAttr, "TypeFn::cast_signed">:$cast

#       ODS:  let builders =
#       ODS:  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
#  ODS-NEXT:       "ValueRange":$outputs,
#  ODS-NEXT:       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),

#       ODS:  (ins "TypeRange":$resultTensorTypes, "ValueRange":$inputs,
#  ODS-NEXT:       "ValueRange":$outputs, "Attribute":$cast,
#  ODS-NEXT:       CArg<"ArrayRef<NamedAttribute>", "{}">:$attributes),

#       ODS:    buildStructuredOp($_builder, $_state, resultTensorTypes,
#  ODS-NEXT:      attributes, Test1Op::getRegionBuilder())


# IMPL-LABEL:  void Test1Op::regionBuilder(ImplicitLocOpBuilder &b,
#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
#       IMPL:  TypeFn castVal = TypeFn::cast_signed;
#  IMPL-NEXT:  auto castIter = llvm::find_if(attrs, [&](const NamedAttribute &attr) {
#  IMPL-NEXT:                                return attr.getName() == "cast"; });
#  IMPL-NEXT:  if (castIter != attrs.end()) {
#  IMPL-NEXT:    if (auto attr = castIter->getValue().dyn_cast<TypeFnAttr>())
#  IMPL-NEXT:      castVal = attr.getValue();
#  IMPL-NEXT:  }

#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.constant("42 : i64");
#   IMPL-DAG:  Value [[VAL1:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL0]]);
#   IMPL-DAG:  Value [[VAL2:[a-z0-9]+]] = helper.index(1);
#   IMPL-DAG:  Value [[VAL3:[a-z0-9]+]] = helper.buildTypeFn(castVal, block.getArgument(0).getType(), [[VAL2]]);
#   IMPL-DAG:  Value [[VAL4:[a-z0-9]+]] = helper.buildBinaryFn(BinaryFn::add, [[VAL1]], [[VAL3]]);


# @linalg_structured_op
# def test2(I=TensorDef(T, S.M, S.N),
#           O=TensorDef(T, S.M, S.N, output=True),
#           strides=IndexAttrDef(S.SM, S.SN, default=[1, 2])):
#   """Title.

#   Detailed description.
#   """
#   O[D.m, D.n] = I[D.n * S.SM, D.m * S.SN]

# Config for op `test2` (C++ class Test2Op): copies input tensor I to output
# tensor O, with the input accessed through an indexing map that depends on an
# index attribute `strides`. The FileCheck expectations after this document
# cover the attribute declaration, symbol bindings, indexing maps, attribute
# verification and the region builder.
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: test2
  cpp_class_name: Test2Op
  doc: |-
    Title.

    Detailed description.
structured_op: !LinalgStructuredOpConfig
  # All maps in this document share the symbol list [s0, s1, s2, s3]:
  # the tensor shapes use (s0, s1) and the `strides` attribute uses (s2, s3).
  args:
  - !LinalgOperandDefConfig
    name: I
    kind: input_tensor
    type_var: T
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: T
    shape_map: affine_map<()[s0, s1, s2, s3] -> (s0, s1)>
  # Two-element index attribute with default values 1 and 2.
  - !LinalgOperandDefConfig
    name: strides
    kind: index_attr
    index_attr_map: affine_map<()[s0, s1, s2, s3] -> (s2, s3)>
    default_indices:
    - 1
    - 2
  # First map is for I (dimensions swapped and scaled by s2/s3, matching
  # I[D.n * S.SM, D.m * S.SN] in the Python definition above); second is for O.
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>
    - affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>
  iterator_types:
  - parallel
  - parallel
  # O = I: the input scalar is assigned to the output unchanged.
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_arg: I

# ODS-LABEL:  def Test2Op : LinalgStructuredBase_Op<"test2"

#       ODS:  let arguments =
#  ODS-NEXT:    Variadic<AnyType>:$inputs,
#  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
#  ODS-NEXT:    DefaultValuedOptionalAttr<RankedI64ElementsAttr<[2]>
#  ODS-SAME:    "{ static_cast<int64_t>(1), static_cast<int64_t>(2) }">:$strides

#       ODS:  "Attribute":$strides
#       ODS:  $_state.addAttribute("strides", strides);

#       ODS:  bool hasDynamicIndexingMaps();
#  ODS-NEXT:  LogicalResult verifyIndexingMapRequiredAttributes();

#       IMPL:  getSymbolBindings(Test2Op self)
#       IMPL:  cst2 = self.getStrides().getValues<int64_t>()[0];
#  IMPL-NEXT:  getAffineConstantExpr(cst2, context)
#       IMPL:  cst3 = self.getStrides().getValues<int64_t>()[1];
#  IMPL-NEXT:  getAffineConstantExpr(cst3, context)

#       IMPL:  Test2Op::getIndexingMaps()
#       IMPL:  = getSymbolBindings(*this);
#       IMPL:  "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d1 * s2, d0 * s3)>"
#       IMPL:  "affine_map<(d0, d1)[s0, s1, s2, s3] -> (d0, d1)>"

#       IMPL:  Test2Op::getNumRegionArgs() { return 2; }

#       IMPL:  Test2Op::hasDynamicIndexingMaps() { return true; }
#       IMPL:  Test2Op::verifyIndexingMapRequiredAttributes()
#       IMPL:  auto attr = op->getAttrOfType<DenseElementsAttr>("strides")
#       IMPL:  "incorrect element type for index attribute 'strides'"
#       IMPL:  "incorrect shape for index attribute 'strides'"
#       IMPL:  void Test2Op::regionBuilder(ImplicitLocOpBuilder &b,
#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
#  IMPL-NEXT:    assert(2 > 0 && block.getNumArguments() == 2 &&

#       IMPL:   yields.push_back(block.getArgument(0));

# @linalg_structured_op
# def test3(value=ScalarDef(T1),
#           O=TensorDef(U, output=True)):
#   """Title.

#   Detailed description.
#   """
#   O[None] = TypeFn.cast_signed(U, value)

# Config for op `test3` (C++ class Test3Op): a scalar input `value` cast into
# a rank-0 output tensor O. The FileCheck expectations after this document
# look at the generated getIteratorTypesArray() and getIndexingMaps() bodies,
# which are expected to be computed from the rank of the init operand.
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: test3
  cpp_class_name: Test3Op
  doc: |-
    Title.

    Detailed description.
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: value
    kind: scalar
    type_var: T1
  # Output declared with an empty shape map (no shape symbols).
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<() -> ()>
  # One empty map per operand (value, then O).
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<() -> ()>
    - affine_map<() -> ()>
  # No iterators are declared for this op.
  iterator_types: []
  # O = cast_signed(U, value); the function is fixed via fn_name here, not
  # selected through an attribute.
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: type
        fn_name: cast_signed
        type_var: U
        operands:
        - !ScalarExpression
          scalar_arg: value

#       IMPL:  Test3Op::getIteratorTypesArray() {
#  IMPL-NEXT:    int64_t rank = getRank(getDpsInitOperand(0));

#       IMPL:  Test3Op::getIndexingMaps() {
#  IMPL-NEXT:    MLIRContext *context = getContext();
#  IMPL-NEXT:    AffineMap scalarMap = AffineMap::get(getNumParallelLoops(), 0, context);
#  IMPL-NEXT:    AffineMap tensorMap = AffineMap::getMultiDimIdentityMap(

# @linalg_structured_op
# def test4(O=TensorDef(T, S.M, S.N, output=True),
#           unary_fun=UnaryFnAttrDef(default=UnaryFn.exp),
#           binary_fun=BinaryFnAttrDef(default=BinaryFn.add)):
#   """Title.

#   Detailed description.
#   """
#   O[D.m, D.n] = binary_fun(unary_fun(O[D.m, D.n]), O[D.m, D.n])

# Config for op `test4` (C++ class Test4Op): an in-place update of output
# tensor O whose unary and binary functions are both selected by attributes
# (`unary_fun` defaulting to exp, `binary_fun` defaulting to add). The
# FileCheck expectations after this document cover the attribute arguments,
# the builder that adds them, and the region builder.
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: test4
  cpp_class_name: Test4Op
  doc: |-
    Title.

    Detailed description.
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: T
    shape_map: affine_map<()[s0, s1] -> (s0, s1)>
  # Attribute selecting the unary function applied to O.
  - !LinalgOperandDefConfig
    name: unary_fun
    kind: unary_fn_attr
    default_fn: exp
  # Attribute selecting the binary function combining the two operands.
  - !LinalgOperandDefConfig
    name: binary_fun
    kind: binary_fn_attr
    default_fn: add
  # One indexing map, for the only tensor operand (O).
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<(d0, d1)[s0, s1] -> (d0, d1)>
  iterator_types:
  - parallel
  - parallel
  # O = binary_fun(unary_fun(O), O); both functions are referenced by
  # attribute name (attr_name) rather than by a fixed function name.
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: binary
        attr_name: binary_fun
        operands:
        - !ScalarExpression
          scalar_fn:
            kind: unary
            attr_name: unary_fun
            operands:
            - !ScalarExpression
              scalar_arg: O
        - !ScalarExpression
          scalar_arg: O

# ODS-LABEL:  def Test4Op : LinalgStructuredBase_Op<"test4"

#       ODS:  let arguments =
#  ODS-NEXT:    Variadic<AnyType>:$inputs,
#  ODS-NEXT:    Variadic<AnyShaped>:$outputs,
#  ODS-NEXT:    DefaultValuedOptionalAttr<UnaryFnAttr, "UnaryFn::exp">:$unary_fun,
#  ODS-NEXT:    DefaultValuedOptionalAttr<BinaryFnAttr, "BinaryFn::add">:$binary_fun

#       ODS:    "Attribute":$unary_fun, "Attribute":$binary_fun,

#       ODS:    $_state.addAttribute("unary_fun", unary_fun)
#  ODS-NEXT:    $_state.addAttribute("binary_fun", binary_fun)

# IMPL-LABEL:  void Test4Op::regionBuilder(ImplicitLocOpBuilder &b,
#  IMPL-NEXT:    Block &block, ArrayRef<NamedAttribute> attrs)
#       IMPL:  UnaryFn unary_funVal = UnaryFn::exp
#       IMPL:  BinaryFn binary_funVal = BinaryFn::add

#       IMPL:  Value [[VAL0:[a-z0-9]+]] = helper.buildUnaryFn(unary_funVal, block.getArgument(0))
#  IMPL-NEXT:  Value [[VAL1:[a-z0-9]+]] = helper.buildBinaryFn(binary_funVal, [[VAL0]], block.getArgument(0))
#  IMPL-NEXT:  yields.push_back([[VAL1]])

# @linalg_structured_op
# def test5(value=ScalarDef(T1), O=TensorDef(U, output=True)):
#   """Title.

#   Detailed description.
#   """
#   implements(FillOpInterface)
#   defines(Canonicalizer)
#   O[None] = TypeFn.cast(U, value)

# Config for op `test5` (C++ class Test5Op): same operand layout as test3
# (scalar `value` cast into output tensor O), with additional `implements`
# and `defines` metadata. The FileCheck expectations after this document only
# cover the generated op definition: the interface listed as an extra
# interface and the canonicalizer flag.
--- !LinalgOpConfig
metadata: !LinalgOpMetadata
  name: test5
  cpp_class_name: Test5Op
  doc: |-
    Title.

    Detailed description.
  implements:
  # Expected to show up in the extraInterfaces list of the generated op.
  - LinalgFillOpInterface
  defines:
  # Expected to produce `let hasCanonicalizer = 1;` in the generated op.
  - hasCanonicalizer
structured_op: !LinalgStructuredOpConfig
  args:
  - !LinalgOperandDefConfig
    name: value
    kind: scalar
    type_var: T1
  - !LinalgOperandDefConfig
    name: O
    kind: output_tensor
    type_var: U
    shape_map: affine_map<() -> ()>
  # One empty map per operand (value, then O).
  indexing_maps: !LinalgIndexingMapsConfig
    static_indexing_maps:
    - affine_map<() -> ()>
    - affine_map<() -> ()>
  iterator_types: []
  # O = cast(U, value), with the function fixed via fn_name.
  # NOTE(review): the other documents in this file use `cast_signed`; none of
  # the expectations for test5 inspect this name. Confirm whether plain `cast`
  # is still intended here or is a leftover from an older function name.
  assignments:
  - !ScalarAssign
    arg: O
    value: !ScalarExpression
      scalar_fn:
        kind: type
        fn_name: cast
        type_var: U
        operands:
        - !ScalarExpression
          scalar_arg: value

# ODS-LABEL:  def Test5Op : LinalgStructuredBase_Op<"test5"
#  ODS-NEXT:  /*extraInterfaces=*/[LinalgFillOpInterface])>

#       ODS:  let hasCanonicalizer = 1;
